# -*- coding: utf-8 -*-
"""Tesi.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1CGTUPViBEulk5VMleO0VANZF0-Pl5jdQ
"""

import os
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
import numpy as np
from tensorflow.keras.models import Model
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.callbacks import LearningRateScheduler, EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset paths used below live under it.
drive.mount("/content/drive")

# Notebook shell escape (not plain Python): list the dataset folder contents.
# Presumably a sanity check that the mount succeeded -- only works inside Colab/IPython.
!ls /content/drive/MyDrive/dataset-resized/

# Dataset locations on the mounted Drive: the training images and the folder
# that the prediction/evaluation code further down reads from.
dataset_path = '/content/drive/MyDrive/dataset-resized/train/'
dataset_val = '/content/drive/MyDrive/dataset-resized/test/'

# Report up front whether the training folder is reachable.
if os.path.exists(dataset_path):
    print(f"Training data path exists: {dataset_path}")
else:
    print(f"Training data path does not exist: {dataset_path}")

# Input geometry expected by the network (width, height, colour channels).
IMAGE_WIDTH, IMAGE_HEIGHT, IMAGE_CHANNELS = 150, 150, 3

# Training hyper-parameters.
BATCH_SIZE = 64
EPOCHS = 50

# Alias used by main() for the training folder.
TRAIN_DATA_PATH = dataset_path

def load_and_preprocess_image(path):
    """Read an image file and return it resized and scaled to [0, 1].

    The file is decoded as a 3-channel image, resized to
    ``IMAGE_HEIGHT`` x ``IMAGE_WIDTH`` and divided by 255.

    Args:
        path: Path of the image file to read.

    Returns:
        The preprocessed image tensor, or ``None`` (after printing a message)
        when reading raises ``tf.errors.NotFoundError``.
    """
    target_size = [IMAGE_HEIGHT, IMAGE_WIDTH]
    try:
        raw_bytes = tf.io.read_file(path)
        decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
        resized = tf.image.resize(decoded, target_size)
    except tf.errors.NotFoundError:
        print("File not found:", path)
        return None
    # Normalize pixel values to [0, 1].
    return resized / 255.0


def get_filenames_and_labels(directory):
    """Collect image paths and integer labels from a folder-per-class dataset.

    Every non-hidden sub-directory of ``directory`` is treated as one class.
    Classes are numbered in sorted name order, and files within each class are
    returned in sorted order so the result is reproducible across runs.

    Args:
        directory: Root folder containing one sub-folder per class.

    Returns:
        A ``(filenames, labels, class_names)`` tuple: the image paths, the
        class index of each path (parallel to ``filenames``), and the sorted
        list of class folder names.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    # Select class folders explicitly instead of dropping "the first entry":
    # hidden entries (e.g. .DS_Store, .ipynb_checkpoints) and stray files are
    # skipped, and no real class is lost when no hidden entry is present.
    class_names = sorted(
        entry for entry in os.listdir(directory)
        if not entry.startswith('.')
        and os.path.isdir(os.path.join(directory, entry))
    )

    filenames = []
    labels = []
    for label_index, label_name in enumerate(class_names):
        class_dir = os.path.join(directory, label_name)
        # Case-insensitive match so files such as "IMG_1.JPG" are not missed.
        class_files = sorted(
            os.path.join(class_dir, name)
            for name in os.listdir(class_dir)
            if name.lower().endswith(image_extensions)
        )
        filenames.extend(class_files)
        labels.extend([label_index] * len(class_files))

    return filenames, labels, class_names

def augment_image(image, label):
    """Return a randomly perturbed copy of ``image`` with its label unchanged.

    Applies, in order: random horizontal flip, random vertical flip, random
    brightness shift (up to +/-0.2), random contrast and saturation scaling
    (both within [0.95, 1.05]), then clips the result to [0, 1].

    Args:
        image: Image tensor to augment.
        label: Label associated with the image; passed through untouched.

    Returns:
        An ``(augmented_image, label)`` tuple.
    """
    # Geometric perturbations.
    flipped = tf.image.random_flip_up_down(
        tf.image.random_flip_left_right(image)
    )

    # Photometric perturbations.
    brightened = tf.image.random_brightness(flipped, max_delta=0.2)
    contrasted = tf.image.random_contrast(brightened, lower=0.95, upper=1.05)
    saturated = tf.image.random_saturation(contrasted, lower=0.95, upper=1.05)

    # The brightness/contrast changes can push values outside [0, 1].
    return tf.clip_by_value(saturated, 0.0, 1.0), label

def create_datasets_for_fold(train_files, train_labels, val_files, val_labels, class_names,
                             num_augmented_copies=18, augment_validation=True):
    """Build the batched training and validation pipelines for one fold.

    Each split consists of the decoded original images followed by
    ``num_augmented_copies`` additional passes through ``augment_image``.
    Labels are one-hot encoded with ``len(class_names)`` classes.

    Args:
        train_files: Image paths of the training split.
        train_labels: Integer class indices parallel to ``train_files``.
        val_files: Image paths of the validation split.
        val_labels: Integer class indices parallel to ``val_files``.
        class_names: Class names; only the count is used (one-hot depth).
        num_augmented_copies: How many augmented copies of each split are
            appended to the originals. Defaults to 18.
        augment_validation: Whether augmented copies are also appended to the
            validation split. Defaults to True to keep the existing behaviour.
            NOTE(review): validation metrics then depend on random
            augmentations; consider passing False.

    Returns:
        A ``(train_dataset, validation_dataset)`` tuple of batched, prefetched
        ``tf.data.Dataset`` objects.
    """
    num_classes = len(class_names)

    def load_example(path, label):
        # Decode the image and turn the integer label into a one-hot vector.
        return load_and_preprocess_image(path), tf.one_hot(label, depth=num_classes)

    def with_augmented_copies(base_dataset, copies):
        # Originals first, then `copies` independently augmented passes.
        combined = base_dataset
        for _ in range(copies):
            combined = combined.concatenate(
                base_dataset.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
            )
        return combined

    train_data = tf.data.Dataset.from_tensor_slices((train_files, train_labels))
    val_data = tf.data.Dataset.from_tensor_slices((val_files, val_labels))

    # Shuffle at the file-path level with a buffer covering the whole split.
    # Paths are cheap to buffer, and without this the 1024-element shuffle
    # below only mixes neighbouring examples of a dataset that is much larger
    # than the buffer, so batches follow the input order too closely.
    train_data = train_data.shuffle(buffer_size=max(len(train_files), 1))

    train_dataset = train_data.map(load_example, num_parallel_calls=tf.data.AUTOTUNE)
    validation_dataset = val_data.map(load_example, num_parallel_calls=tf.data.AUTOTUNE)

    full_train_dataset = with_augmented_copies(train_dataset, num_augmented_copies)
    full_val_dataset = with_augmented_copies(
        validation_dataset, num_augmented_copies if augment_validation else 0
    )

    train_dataset = (
        full_train_dataset
        .shuffle(buffer_size=1024)
        .batch(BATCH_SIZE)
        .prefetch(tf.data.AUTOTUNE)
    )
    # Validation is not shuffled: loss/accuracy are averages over the whole
    # split, so example order does not affect them.
    validation_dataset = full_val_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

    return train_dataset, validation_dataset

def build_model(num_classes):
    """Create and compile the CNN classifier.

    The network is five Conv2D(3x3, relu) -> MaxPooling2D(2x2) -> Dropout(0.35)
    stages with 32, 64, 64, 128 and 128 filters, followed by Flatten, a
    512-unit relu Dense layer, Dropout(0.8) and a softmax output layer.
    It is compiled with Adam driven by an ExponentialDecay learning-rate
    schedule, categorical cross-entropy loss and the accuracy metric.

    Args:
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    conv_filters = (32, 64, 64, 128, 128)

    stack = [Input(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNELS))]
    for filters in conv_filters:
        stack.append(Conv2D(filters, (3, 3), activation='relu'))
        stack.append(MaxPooling2D(2, 2))
        stack.append(Dropout(0.35))

    # Classification head.
    stack.append(Flatten())
    stack.append(Dense(512, activation='relu'))
    stack.append(Dropout(0.8))
    stack.append(Dense(num_classes, activation='softmax'))

    model = tf.keras.Sequential(stack)

    # Learning rate starts at 1e-4 and is multiplied by 0.97 every 1000 steps.
    lr_schedule = ExponentialDecay(
        0.0001,
        decay_steps=1000,
        decay_rate=0.97,
        staircase=True,
    )
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

def perform_cross_validation(filenames, labels, class_names, num_folds=5):
    """Train and evaluate a fresh model on each stratified fold.

    For every fold a new model is built, trained with early stopping and
    checkpointing, and evaluated on that fold's validation split. Each fold
    writes its own checkpoint (``best_model_fold_<n>.keras``); after all folds
    the checkpoint with the lowest validation loss is copied to
    ``best_model.keras``.

    Args:
        filenames: Image paths (array-like of str).
        labels: Integer class indices parallel to ``filenames``.
        class_names: Class names; the count sets the model's output size.
        num_folds: Number of stratified folds. Defaults to 5.

    Returns:
        A list with one ``(loss, accuracy)`` tuple per fold, as returned by
        ``model.evaluate`` on the fold's validation dataset.
    """
    import shutil  # Local import: only needed to publish the best checkpoint.

    # The index-array selection below needs NumPy arrays; accept lists too.
    filenames = np.asarray(filenames)
    labels = np.asarray(labels)

    skf = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
    num_classes = len(class_names)
    results = []

    best_val_loss = float('inf')
    best_checkpoint_path = None

    for fold_no, (train_index, val_index) in enumerate(skf.split(filenames, labels), start=1):
        print(f"Training on fold {fold_no}...")

        # Create datasets for this fold
        train_dataset, validation_dataset = create_datasets_for_fold(
            filenames[train_index], labels[train_index],
            filenames[val_index], labels[val_index],
            class_names,
        )

        # Build a new model for this fold
        model = build_model(num_classes)

        # No ReduceLROnPlateau here: build_model() gives the optimizer an
        # ExponentialDecay schedule, and Keras refuses to overwrite the
        # learning rate of a schedule-driven optimizer (TypeError as soon as
        # the callback tries to reduce it).
        early_stopping = EarlyStopping(monitor='val_accuracy', patience=7, restore_best_weights=True)

        # One checkpoint file per fold. A shared path would be overwritten by
        # each new fold's first epoch, because every ModelCheckpoint starts
        # with no "best" value of its own.
        fold_checkpoint_path = f'best_model_fold_{fold_no}.keras'
        checkpoint = ModelCheckpoint(fold_checkpoint_path, monitor='val_loss', save_best_only=True)

        # Train the model
        history = model.fit(
            train_dataset,
            epochs=EPOCHS,
            validation_data=validation_dataset,
            callbacks=[early_stopping, checkpoint],
        )

        # Evaluate the model on the validation dataset
        loss, accuracy = model.evaluate(validation_dataset)
        print(f"Fold {fold_no} - Loss: {loss}, Accuracy: {accuracy}")
        results.append((loss, accuracy))

        # The checkpoint on disk corresponds to this fold's lowest val_loss.
        fold_val_loss = min(history.history.get('val_loss', ()), default=float('inf'))
        if fold_val_loss < best_val_loss and os.path.exists(fold_checkpoint_path):
            best_val_loss = fold_val_loss
            best_checkpoint_path = fold_checkpoint_path

    # Publish the best fold's checkpoint under the name the test code loads.
    if best_checkpoint_path is not None:
        shutil.copyfile(best_checkpoint_path, 'best_model.keras')

    return results

def main():
    """Run 2-fold cross-validation on the training folder and print averages.

    Raises:
        FileNotFoundError: If ``TRAIN_DATA_PATH`` does not exist.
        ValueError: If no image files are found under ``TRAIN_DATA_PATH``.
    """
    # Check data paths. FileNotFoundError is a subclass of Exception, so
    # callers that caught the previous generic Exception keep working.
    if not os.path.exists(TRAIN_DATA_PATH):
        raise FileNotFoundError(f"Training data path does not exist: {TRAIN_DATA_PATH}")

    filenames, labels, class_names = get_filenames_and_labels(TRAIN_DATA_PATH)
    if not filenames:
        # Fail here with a clear message rather than inside the fold splitter.
        raise ValueError(f"No image files found under: {TRAIN_DATA_PATH}")

    # Index-array selection inside the cross-validation needs NumPy arrays.
    filenames = np.array(filenames)
    labels = np.array(labels)

    # Perform cross-validation
    results = perform_cross_validation(filenames, labels, class_names, num_folds=2)

    # Print the average performance across all folds; each result is a
    # (loss, accuracy) pair, so a column-wise mean gives both averages.
    avg_loss, avg_accuracy = np.mean(results, axis=0)
    print(f"Average Loss: {avg_loss}, Average Accuracy: {avg_accuracy}")


if __name__ == "__main__":
    main()

# Testing code
# Load the model file written during training ('best_model.keras' is the path
# given to ModelCheckpoint). This runs at module level, so it executes as soon
# as this part of the file is reached.
model = tf.keras.models.load_model('best_model.keras')
# Class names in label-index order; predict_all_images() uses each name as a
# sub-folder of the test directory and its position as the true label.
# NOTE(review): presumably must match the sorted class-folder order used for
# training labels -- confirm against the dataset folders.
class_names = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

def predict_all_images(data_directory):
    """Predict the class of every image under ``data_directory``.

    The directory is expected to hold one sub-folder per entry of the
    module-level ``class_names``; the position of the name in that list is
    used as the true label. Predictions come from the module-level ``model``.

    Args:
        data_directory: Root folder containing one sub-folder per class.

    Returns:
        A ``(true_labels, predictions)`` tuple of equal-length lists holding
        the true and predicted class indices of every image that was loaded.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')

    predictions = []
    true_labels = []

    for class_index, class_name in enumerate(class_names):
        class_path = os.path.join(data_directory, class_name)
        # Only feed image files to the decoder; stray entries such as
        # .DS_Store would otherwise make decoding fail. Sorted for a
        # reproducible output order.
        image_files = sorted(
            os.path.join(class_path, fname)
            for fname in os.listdir(class_path)
            if fname.lower().endswith(image_extensions)
        )

        # Predict in batches instead of calling model.predict once per image.
        for start in range(0, len(image_files), BATCH_SIZE):
            loaded = (
                load_and_preprocess_image(image_file)
                for image_file in image_files[start:start + BATCH_SIZE]
            )
            # load_and_preprocess_image returns None for files it cannot find.
            images = [img for img in loaded if img is not None]
            if not images:
                continue

            probabilities = model.predict(tf.stack(images), verbose=0)
            predictions.extend(np.argmax(probabilities, axis=1))
            true_labels.extend([class_index] * len(images))

    return true_labels, predictions

def plot_confusion_matrix(true_labels, predictions, class_names):
    """Plot the confusion matrix of ``predictions`` against ``true_labels``.

    Args:
        true_labels: True class indices (positions in ``class_names``).
        predictions: Predicted class indices, parallel to ``true_labels``.
        class_names: Class names used as the axis tick labels.
    """
    # Pass the full label set explicitly. Otherwise a class that appears in
    # neither list is dropped from the matrix, which then no longer matches
    # the number (and order) of display labels.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(true_labels, predictions, labels=label_ids)

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.show()

# Test-time entry point: classify every image in the test folder, print the
# true/predicted class of each one, then show the confusion matrix.
if __name__ == "__main__":
    # dataset_val is the test-folder path defined near the top of the file.
    data_directory = dataset_val
    true_labels, predictions = predict_all_images(data_directory)

    # Both lists hold class indices; map them back to names for display.
    for true_label, predicted_label in zip(true_labels, predictions):
         print(f'True class: {class_names[true_label]}, Predicted class: {class_names[predicted_label]}')

    # Same list as the module-level class_names assigned earlier (re-assigned here).
    class_names = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
    plot_confusion_matrix(true_labels, predictions, class_names)